-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[flang][cuda] Update device descriptor on data transfer #114838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-runtime Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: When the destination of the data transfer is a global, we might need to sync the descriptor after the data transfer is done. This is the case when the data transfer is from host or device to device, as reallocation might have happened and the descriptor on the device needs to take the new values written on the host. A new entry point, `CUFDataTransferGlobalDescDesc`, is added to perform the sync when needed. Full diff: https://github.com/llvm/llvm-project/pull/114838.diff 4 Files Affected:
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 6d2e0c0f15942b..51d6b8d4545f09 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -49,6 +49,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+/// Data transfer from a descriptor to a global descriptor.
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
+ unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
} // extern "C"
} // namespace Fortran::runtime::cuda
#endif // FORTRAN_RUNTIME_CUDA_MEMORY_H_
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 4050064ebe95d8..a28d0a562f2f0b 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -429,6 +429,16 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};
+static bool isDstGlobal(cuf::DataTransferOp op) {
+ if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
+ if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
+ return true;
+ if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
+ if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
+ return true;
+ return false;
+}
+
struct CUFDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
@@ -522,8 +532,11 @@ struct CUFDataTransferOpConversion
mlir::isa<fir::BaseBoxType>(dstTy)) {
// Transfer between two descriptor.
mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
- loc, builder);
+ isDstGlobal(op)
+ ? fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFDataTransferGlobalDescDesc)>(loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
+ loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index daf1db684a3d2e..d8b79e9b88a68f 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -9,6 +9,7 @@
#include "flang/Runtime/CUDA/memory.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/assign.h"
#include "cuda_runtime.h"
@@ -123,5 +124,18 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}
+
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
+ Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+ int sourceLine) {
+ RTNAME(CUFDataTransferDescDesc)(
+ dstDesc, srcDesc, mode, sourceFile, sourceLine);
+ if ((mode == kHostToDevice) || (mode == kDeviceToDevice)) {
+ void *deviceAddr{
+ RTNAME(CUFGetDeviceAddress)((void *)dstDesc, sourceFile, sourceLine)};
+ RTNAME(CUFDescriptorSync)
+ ((Descriptor *)deviceAddr, srcDesc, sourceFile, sourceLine);
+ }
+}
}
} // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index cee3048e279cc7..a760650d143583 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -224,4 +224,29 @@ func.func @_QPsub9() {
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+ %c0 = arith.constant 0 : index
+ %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+ fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+
+func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
+ %0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+ %2 = fir.address_of(@_QFEahost) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ %3:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+ cuf.data_transfer %3#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ return
+}
+
+// CHECK-LABEL: func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"}
+// CHECK: %[[GLOBAL_ADDRESS:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL_ADDRESS]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+
+
} // end of module
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
When the destination of the data transfer is a global, we might need to sync the descriptor after the data transfer is done. This is the case when the data transfer is from host or device to device, as reallocation might have happened and the descriptor on the device needs to take the new values written on the host. A new entry point, `CUFDataTransferGlobalDescDesc`, is added to perform the sync when needed.
When the destination of the data transfer is a global we might need to sync the descriptor after the data transfer is done. This is the case when the data transfer is from host/device to device as reallocation might have happened and the descriptor on the device needs to take the new values written on the host.
A new entry point, `CUFDataTransferGlobalDescDesc`, is added to perform the sync when needed.